#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import glob, sys
import numpy as np
import xarray as xr
import pandas as pd
import warnings
# Suppress every warning for the whole process.
# NOTE(review): this also hides numpy RuntimeWarnings about invalid/NaN values
# and library deprecation notices; consider narrowing the filter to the
# specific categories that are known to be noisy.
warnings.filterwarnings('ignore')


def Photo_calculation(lat,N):
    """Return the astronomical day length (photoperiod) in hours.

    Parameters
    ----------
    lat : float or array_like
        Latitude in degrees (positive north).
    N : int, float or array_like
        Day of the year; broadcast against ``lat``.

    Returns
    -------
    numpy.ndarray
        Day length in hours, in [0, 24], with the broadcast shape of ``lat``
        and ``N`` (a 0-d array for scalar inputs). Polar night yields 0 and
        polar day yields 24; NaN inputs propagate as NaN.
    """
    lat = np.asarray(lat)
    N = np.asarray(N)
    # Solar declination in degrees (simple cosine approximation).
    declination = -23.44*np.cos(np.radians((360./365.)*(N+10.)))
    # Cosine of the sunrise hour angle. A value >= 1 means the sun never
    # rises that day; a value <= -1 means it never sets.
    cos_hourangle = -np.tan(np.radians(lat))*np.tan(np.radians(declination))
    # Clip into arccos' domain instead of relying on NaN + masked fix-ups,
    # which also lets scalar inputs work (no item assignment needed).
    sunrise_hourangle = np.arccos(np.clip(cos_hourangle, -1., 1.))
    Photo = 2.*(24./(2.*np.pi))*sunrise_hourangle
    Photo = np.where(cos_hourangle >= 1., 0., Photo)    # polar night
    Photo = np.where(cos_hourangle <= -1., 24., Photo)  # polar day
    return Photo

def VPD_calculation(T,TD):
    """Return the vapour-pressure deficit in Pa.

    Parameters
    ----------
    T : array_like (numpy or xarray)
        Air temperature in Kelvin.
    TD : array_like (numpy or xarray)
        Dew-point temperature in Kelvin.

    Returns
    -------
    Same type as the inputs
        Saturation vapour pressure at ``T`` minus the actual vapour pressure,
        which is the saturation value at ``TD``.

    NOTE(review): the coefficients appear to be the ITS-90 saturation vapour
    pressure formulation over water (Hardy, 1998). Confirm before reuse.
    """
    g0, g1, g2, g3 = -2.8365744e03, -6.028076559e03, 1.954263612e01, -2.737830188e-02
    g4, g5, g6, g7 = 1.6261698e-05, 7.0229056e-10, -1.8680009e-13, 2.7150305

    def vapour_pressure(temp):
        # ln(e) = g0*T^-2 + g1*T^-1 + g2 + g3*T + g4*T^2 + g5*T^3 + g6*T^4 + g7*ln(T); result in Pa
        return np.exp(g0*temp**(-2.)+g1*temp**(-1.)+g2+g3*temp**(1.)+g4*temp**(2.)+g5*temp**(3.)+g6*temp**(4.)+g7*np.log(temp))

    saturation_vp = vapour_pressure(T)
    actual_vp = vapour_pressure(TD)
    return saturation_vp-actual_vp

def iPhoto_calculation(Photo, Photo_ll=10., Photo_ul=11.):
    """Return the photoperiod indicator of the Growing Season Index.

    The indicator rises linearly from 0 at ``Photo_ll`` hours of daylight to
    1 at ``Photo_ul`` hours and is held at 0 / 1 outside that range.

    Parameters
    ----------
    Photo : float or array_like
        Day length in hours (as returned by ``Photo_calculation``).
    Photo_ll, Photo_ul : float, optional
        Lower and upper day-length limits in hours (defaults 10 and 11;
        presumably the Jolly et al. (2005) GSI thresholds — confirm).

    Returns
    -------
    numpy.ndarray or numpy scalar
        Indicator in [0, 1]; NaN inputs stay NaN.
    """
    Photo = np.asarray(Photo)
    # np.clip replaces masked assignment, so scalar inputs work as well.
    return np.clip((Photo-Photo_ll)/(Photo_ul-Photo_ll), 0., 1.)

def iVPD_calculation(VPD, VPD_ll=900., VPD_ul=4100.):
    """Return the vapour-pressure-deficit indicator of the Growing Season Index.

    The indicator is 1 at or below ``VPD_ll``, falls linearly to 0 at
    ``VPD_ul`` and stays 0 above it.

    Parameters
    ----------
    VPD : float or array_like
        Vapour-pressure deficit in Pa.
    VPD_ll, VPD_ul : float, optional
        Lower and upper VPD limits in Pa (defaults 900 and 4100; presumably
        the Jolly et al. (2005) GSI thresholds — confirm).

    Returns
    -------
    numpy.ndarray or numpy scalar
        Indicator in [0, 1]; NaN inputs stay NaN.
    """
    VPD = np.asarray(VPD)
    # np.clip replaces masked assignment, so scalar inputs work as well.
    return np.clip(1-(VPD-VPD_ll)/(VPD_ul-VPD_ll), 0., 1.)

def iTmin_calculation(Tmin, Tmin_ll=-2., Tmin_ul=5.):
    """Return the minimum-temperature indicator of the Growing Season Index.

    The indicator rises linearly from 0 at ``Tmin_ll`` degC to 1 at
    ``Tmin_ul`` degC and is held at 0 / 1 outside that range.

    Parameters
    ----------
    Tmin : float or array_like
        Minimum temperature in Kelvin (converted to degC internally).
    Tmin_ll, Tmin_ul : float, optional
        Lower and upper limits in degC (defaults -2 and 5; presumably the
        Jolly et al. (2005) GSI thresholds — confirm).

    Returns
    -------
    numpy.ndarray or numpy scalar
        Indicator in [0, 1]; NaN inputs stay NaN.
    """
    Tmin = np.asarray(Tmin)
    # np.clip replaces masked assignment, so scalar inputs work as well.
    return np.clip((Tmin-273.15-Tmin_ll)/(Tmin_ul-Tmin_ll), 0., 1.)

def GSI_calculation(iPhoto,iVPD,iTmin):
    """Combine the three indicator indices into the Growing Season Index.

    The GSI is the element-wise product of the photoperiod, vapour-pressure
    deficit and minimum-temperature indicators; inputs broadcast under the
    usual numpy rules.
    """
    photo_times_vpd = iPhoto*iVPD
    return photo_times_vpd*iTmin


# Usage: script YEAR MONTH MODEL EXPERIMENT
if len(sys.argv) < 5:
    sys.exit('usage: {} YEAR MONTH MODEL EXPERIMENT'.format(sys.argv[0]))

year = int(sys.argv[1])     # Enter year
month = int(sys.argv[2])    # Enter month (1-12)
# Validate explicitly: indexing a '00'..'12' lookup table would silently
# accept month 0 and negative months (e.g. -1 -> '12').
if not 1 <= month <= 12:
    sys.exit('MONTH must be between 1 and 12, got {}'.format(month))
year_month = '{}-{:02d}'.format(year, month)   # 'YYYY-MM' tag used in every file name

model = sys.argv[3]
experiment = sys.argv[4]

times_decode = False

path_era5 = '../DATA/ERA5/'
ds_t2 = xr.open_dataset(path_era5+'T2m/ERA5_2m_temperature_{}.nc'.format(year_month))          # 2m air temperature in Kelvin
ds_dp = xr.open_dataset(path_era5+'Td2m/ERA5_2m_dewpoint_temperature_{}.nc'.format(year_month))  # 2m dew point temperature in Kelvin

path_cmip6 = '../DATA/CMIP6/'


def _cmip6_anom_file(var):
    """Return the path of the CMIP6 anomaly file of *var* for this model, experiment and month."""
    return '{}{}/{}/Anom_{}/ERA5_{}_{}_{}_{}.nc'.format(
        path_cmip6, model, var, experiment, var, model, experiment, year_month)


ds_t2_anom = xr.open_dataset(_cmip6_anom_file('tas'), decode_times=times_decode)       # 2m air temperature anomaly (K)
ds_tmin_anom = xr.open_dataset(_cmip6_anom_file('tasmin'), decode_times=times_decode)  # tasmin anomaly (K)
ds_dp_anom = xr.open_dataset(_cmip6_anom_file('tdps'), decode_times=times_decode)      # 2m dew point temperature anomaly (K)

path_out_gsi = '../OUTS/GSI/'+model+'/'+experiment+'/'
# No shell involved: model/experiment come from the command line, so building a
# 'mkdir -p ...' command string would allow shell injection and ignore failures.
os.makedirs(path_out_gsi, exist_ok=True)

# Daily minimum of the ERA5 temperature; its daily time axis replaces the
# (undecoded) time coordinate of the anomaly files.
tmin = ds_t2.resample(time='1D').min()
ds_t2_anom['time'] = tmin['time']
ds_dp_anom['time'] = tmin['time']

# Perturb the ERA5 fields: each daily anomaly is aligned to the ERA5 grid/time
# axis and forward-filled onto the ERA5 time steps of that day.
T2 = ds_t2.t2m+ds_t2_anom.tas.reindex_like(ds_t2).ffill('time')
# The dew-point anomaly file is named after the CMIP6 variable `tdps`; use it
# when present and fall back to `tas` for files that store it under that name.
dp_anom_var = 'tdps' if 'tdps' in ds_dp_anom.data_vars else 'tas'
DP = ds_dp.d2m+ds_dp_anom[dp_anom_var].reindex_like(ds_dp).ffill('time')
Tmin = tmin.t2m.values+ds_tmin_anom.tasmin.values
del ds_t2, ds_dp, ds_t2_anom, ds_dp_anom, ds_tmin_anom

# Day of year shaped (time, 1, 1) and latitude shaped (1, lat, 1): broadcasting
# gives Photo the shape (time, lat, 1) without materialising full
# (time, lat, lon) meshgrid cubes.
days = tmin['time'].dt.dayofyear.values
N = days[:, np.newaxis, np.newaxis]
lat = tmin.latitude.values[np.newaxis, :, np.newaxis]
Photo = Photo_calculation(lat, N)
del N, lat
VPD = VPD_calculation(T2, DP)
del T2, DP
VPD = VPD.resample(time='1D').max(dim='time')   # daily maximum VPD
iPhoto = iPhoto_calculation(Photo)
del Photo
iVPD = iVPD_calculation(VPD.values)
del VPD
iTmin = iTmin_calculation(Tmin)
del Tmin
# iPhoto (time, lat, 1) broadcasts against the (time, lat, lon) indicators.
GSI = GSI_calculation(iPhoto, iVPD, iTmin)
del iPhoto, iVPD, iTmin

ds_gsi = xr.DataArray(
    GSI,
    dims=['time', 'latitude', 'longitude'],
    coords={'time': tmin.time, 'latitude': tmin.latitude.values, 'longitude': tmin.longitude.values},
    name='GSI',
    attrs={'standard_name': 'GSI', 'long_name': 'Growing Season Index', 'units': '1'},
)
ds_gsi.to_netcdf(path_out_gsi+'ERA5_gsi_{}.nc'.format(year_month))
